{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "xla_compile.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "f4TSNCvpENrW"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "vamNSA0vEP-m",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "e1oSi4lHFt3z"
      },
      "source": [
        "# XLAコンパイラAPI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "b7noD9NjFRL-"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/xla_compile\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v9YbsuLZaBXy"
      },
      "source": [
        "\n",
        "\n",
        "TensorFlowとXLAライブラリをインポートします。XLAには、一部または全てのモデルを [XLA](https://www.tensorflow.org/extend/xla/) でコンパイルする実験的なAPIである `xla.compile()` が含まれています。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "45kUPj5ZFrRa",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.contrib.compiler import xla"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GZVNiRmTDV-5"
      },
      "source": [
        "必要ないくつかの定数を定義し、 MNISTのデータセットを用意します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "f37TSEGvGX4_",
        "colab": {}
      },
      "source": [
        "# それぞれの入力イメージの大きさは、 28 x 28ピクセル\n",
        "IMAGE_SIZE = 28 * 28\n",
        "# 個別の数字のラベル [0..9] の個数\n",
        "NUM_CLASSES = 10\n",
        "# それぞれのトレーニングバッチ（ステップ）での標本数\n",
        "TRAIN_BATCH_SIZE = 100\n",
        "# トレーニングステップを実行する回数\n",
        "TRAIN_STEPS = 1000"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TiVXchblG5hK",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 173
        },
        "outputId": "e3cd80ce-3809-4c08-f7ed-3a2d9f385f5d"
      },
      "source": [
        "# MNISTデータセットをロードする。\n",
        "train, test = tf.keras.datasets.mnist.load_data()\n",
        "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n",
        "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n",
        "\n",
        "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n",
        "images, labels = iterator.get_next()\n",
        "images = tf.reshape(images, [-1, IMAGE_SIZE])\n",
        "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
            "11493376/11490434 [==============================] - 0s 0us/step\n",
            "WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use `tf.compat.v1.data.get_output_types(dataset)`.\n",
            "WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use `tf.compat.v1.data.get_output_shapes(dataset)`.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "x_ZehpZP-SfS"
      },
      "source": [
        "# モデルを構築する関数の定義\n",
        "\n",
        "以下のコードブロックは、順伝搬と逆伝搬の両方を行う、１つのdenseレイヤーを持つ簡単なモデルを構築する関数を含みます。\n",
        "\n",
        "コードが呼ばれたとき、２つの値を返します。 `y` は、それぞれのターゲットのクラスの予測確率を表す `tf.Tensor` です。 `train_step` は `global_step` の値を増加し、変数の更新を行う `tf.Operation` です。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZbhJl_WvGa3g",
        "colab": {}
      },
      "source": [
        "def build_mnist_model(x, y_):\n",
        "  y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n",
        "\n",
        "  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n",
        "  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n",
        "\n",
        "  return y, train_step"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7Jh3lyQHDfM9"
      },
      "source": [
        "# XLA の有効化\n",
        "\n",
        "XLA を有効化するには `build_mnist_model` 関数を `xla.compile` に渡します。以下のコードブロックは、モデルを `xla.compile()` 関数でラップします。これにより、提供された入力を持つターゲット関数をXLAで実行できます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kYpCXCdRHNuN",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 139
        },
        "outputId": "203aa05d-c54e-404a-eced-4bbfe4a44d5b"
      },
      "source": [
        "[y] = xla.compile(build_mnist_model, inputs=[images, labels])"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4giQh62IrZGF"
      },
      "source": [
        "グラフをコンパイルするとき、XLAはターゲット関数によって構築されたグラフの全てのノードを、いくつかのXLAのオペレータで置き換えます。\n",
        "\n",
        "xla.compileは、生成されたXLAのオペレータから独立して実行できる `tf.Operation` を返しません\n",
        "代わりに、ターゲット関数から返された `tf.Operation` ノードは、返された全ての `tf.Tensor` の値との制御依存関係として追加されます。これにより、 返されたテンソルが評価されるときに、 `tf.Operation` ノードの実行をトリガします。\n",
        "\n",
        " 擬似コードによるxla.compileの実装は、以下のようになります：\n",
        "\n",
        "---\n",
        "```\n",
        "# TensorFlowに、XLAが扱いやすい方法でコードを実行するよう依頼する\n",
        "\n",
        "y, train_step = build_mnist_model(images, labels)\n",
        "with tf.control_dependencies([train_step]):\n",
        "  y = tf.identity(y)\n",
        "\n",
        "# TensorFlowに、XLAが扱いやすい方法でコードの実行を停止するよう依頼する\n",
        "```\n",
        "---\n",
        "\n",
        "xla.compile()は常に `tf.Tensor` のリスト（１要素しか無かったとしても）を返します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TPGas4jjFLZl"
      },
      "source": [
        "もしあなたが構築したグラフを今表示したら、通常のTensorFlowのグラフとそれほど変わらないことがわかり、前に述べたXLAのオペレータを見つけることができないでしょう。これは、あなたが `sess.run()` でグラフを実行しようとしても、実際のコンパイルは後ほど発生するからです。後ほど、TensorFlowは実際にXLAオペレータを生成する一連のグラフ書き換えパスをトリガーします。これは、すべての入力がそろったときに、計算をコンパイルして実行します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EZD1m_n1DxAF"
      },
      "source": [
        "# モデルの学習とテスト"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qe28bAHNHUG2",
        "colab": {}
      },
      "source": [
        "# セッションを作成しすべての変数を初期化。\n",
        "# xla.compile()は、Keras model.fit() APIやTF eager modeとはまだ動作しません。\n",
        "sess = tf.Session()\n",
        "sess.run(tf.global_variables_initializer())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qgsKmz3n2UiW"
      },
      "source": [
        "以下のコードブロックはモデルを学習します。 `y` の評価は、制御依存関係がある `train_step` をトリガします。これは、モデル変数を更新します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_GxF6jTRHVuA",
        "outputId": "926558f2-5792-4eab-a9fc-58e19e409aea",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 207
        }
      },
      "source": [
        "# 学習用データセットを与える\n",
        "sess.run(iterator.make_initializer(train_ds))\n",
        "\n",
        "# TRAIN_STEPS ステップだけ実行する\n",
        "for i in range(TRAIN_STEPS):\n",
        "  sess.run(y)\n",
        "\n",
        "print(\"Model trained for %s steps.\" % TRAIN_STEPS)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use `tf.compat.v1.data.get_output_types(iterator)`.\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use `tf.compat.v1.data.get_output_shapes(iterator)`.\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use `tf.compat.v1.data.get_output_classes(iterator)`.\n",
            "Model trained for 1000 steps.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "dHlQlRSRHXD1",
        "outputId": "588448d9-93e0-401c-e2de-be754ae140ca",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# 学習済みモデルをテストする\n",
        "\n",
        "# テスト用データセットを与える\n",
        "sess.run(iterator.make_initializer(test_ds))\n",
        "\n",
        "# 精度を計算する\n",
        "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n",
        "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
        "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Prediction accuracy after training: 0.91\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ynJQIuzjHYOb",
        "colab": {}
      },
      "source": [
        "# セッションを片付ける\n",
        "sess.close()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}